
[JAX] [PyT] [Common] Enable D=256 BWD cuDNN fused attn for Blackwell CC 10.x #3056

Open
KshitijLakhani wants to merge 15 commits into NVIDIA:main from KshitijLakhani:klakhani/feat/enable-headdim256-bwd-sm100

Collaborator

@KshitijLakhani KshitijLakhani commented May 28, 2026

Description

Support for D=256 BWD on Blackwell (CC 10.x) via the C++ API (which TE uses) was added in cuDNN 9.23 + cuDNN FE 1.24. This PR enables that support in TE attention.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Add a guard when picking the backend (sub-backend) in TE common.
Add tests for the D=256 case in TE PyT and TE JAX.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@KshitijLakhani KshitijLakhani force-pushed the klakhani/feat/enable-headdim256-bwd-sm100 branch from 51ad582 to d177ecf Compare May 28, 2026 23:05
@KshitijLakhani KshitijLakhani changed the title [JAX] [Common] Enable D=256 BWD cuDNN fused attn for Blackwell CC 10.x [JAX] [PyT] [Common] Enable D=256 BWD cuDNN fused attn for Blackwell CC 10.x May 28, 2026
@KshitijLakhani KshitijLakhani marked this pull request as ready for review May 29, 2026 22:56
@KshitijLakhani KshitijLakhani requested a review from cyanguwa as a code owner May 29, 2026 22:56
@KshitijLakhani KshitijLakhani self-assigned this May 29, 2026
Contributor

greptile-apps Bot commented May 29, 2026

Greptile Summary

This PR enables D=256 backward (bprop) support for Blackwell SM10x GPUs via the cuDNN deterministic SDPA bprop kernel introduced in cuDNN 9.23 / FE 1.24. A new guard is added to the nvte_get_fused_attn_backend selection logic and corresponding skip logic is added to both the JAX and PyTorch test suites.

  • C++ backend gate (fused_attn.cpp): A new sub-condition activates NVTE_F16_arbitrary_seqlen for d_qk == d_v == 256, SM10x, cuDNN ≥ 9.23, non-paged layout, and the deterministic kernel's restrictions (no bias, no dropout, vanilla softmax, window-size-compatible mask).
  • JAX tests: _check_configs gains matching skip guards for D=256 on SM10x; two new test params are added — a BSHD happy-path and a THD strict-xfail case that documents the known cuDNN 9.23 plan-build failure for THD layouts.
  • PyTorch tests: A new test_dpa_fused_attn_hdim256 test covers no-mask, padding, causal-SWA, and GQA configurations, guarded by @pytest.mark.skipif for cuDNN 9.23+ and SM100/SM103 devices.
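
As a reading aid, the new sub-condition can be modelled as a plain Python predicate. This is only a sketch of the gate as summarised above, not the TE implementation: `AttnConfig`, `d256_bwd_gate`, and the string values standing in for the NVTE enums are hypothetical names, and the real selector combines this branch with other branches and outer checks.

```python
from dataclasses import dataclass

# Hypothetical string stand-ins for the causal NVTE_Mask_Type values.
CAUSAL_MASKS = {
    "causal",
    "padding_causal",
    "causal_bottom_right",
    "padding_causal_bottom_right",
}


@dataclass(frozen=True)
class AttnConfig:
    head_dim_qk: int
    head_dim_v: int
    is_training: bool
    sm_arch: int            # e.g. 100 for SM100
    cudnn_version: int      # e.g. 92300 for cuDNN 9.23.0
    paged_kv: bool = False
    bias_type: str = "no_bias"
    dropout: float = 0.0
    softmax_type: str = "vanilla"
    mask_type: str = "no_mask"
    window: tuple = (-1, -1)  # (left, right); -1 means unbounded


def d256_bwd_gate(cfg: AttnConfig) -> bool:
    """Model of the D=256 BWD sub-condition only (assumed simplification)."""
    left, right = cfg.window
    full_window = left == -1 and right == -1
    # Sliding windows are allowed only for causal masks with right in {-1, 0}.
    causal_window_ok = cfg.mask_type in CAUSAL_MASKS and right in (-1, 0)
    return (
        cfg.head_dim_qk == 256
        and cfg.head_dim_v == 256
        and cfg.is_training
        and 100 <= cfg.sm_arch < 110
        and cfg.cudnn_version >= 92300
        and not cfg.paged_kv
        and cfg.bias_type == "no_bias"
        and cfg.dropout == 0.0
        and cfg.softmax_type == "vanilla"
        and (full_window or causal_window_ok)
    )
```

For example, `d256_bwd_gate(AttnConfig(256, 256, True, 100, 92300))` passes, while the same config with `sm_arch=90` or a non-zero `dropout` does not.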

Confidence Score: 4/5

Safe to merge for non-THD users; the known THD + D=256 cuDNN plan-build failure is documented with a strict xfail test rather than a backend exclusion, leaving a latent hard exception for production THD + D=256 training on SM10x.

The C++ backend gate correctly activates the new kernel path for BSHD/SBHD layouts and the test suite validates the main happy-path configurations. The THD layout issue remains unguarded in the backend selector, leaving a latent hard runtime exception for production THD + D=256 training workloads on SM10x.

transformer_engine/common/fused_attn/fused_attn.cpp — the new D=256 guard does not exclude THD-format layouts, which cannot build a cuDNN 9.23 execution plan.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/fused_attn/fused_attn.cpp | Adds the D=256 BWD backend-selection guard for SM10x + cuDNN 9.23+; correctly restricts to no-bias/no-dropout/vanilla-softmax/non-paged with matching window-size constraints. THD-layout handling is documented via an xfail test rather than an explicit backend exclusion. |
| tests/jax/test_fused_attn.py | Adds D=256 skip guards in _check_configs mirroring the C++ gate, refactors compute_capability/cudnn_version lookups, and adds a BSHD happy-path test plus a strict-xfail THD case for the known cuDNN 9.23 plan-build failure. |
| tests/pytorch/attention/test_attention.py | Adds test_dpa_fused_attn_hdim256 covering no-mask, padding, causal-SWA, and GQA configurations, guarded by cuDNN 9.23+ and SM100/SM103 skipif decorators. |
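
The test guard described in the last row could look roughly like the check below. It is a sketch under stated assumptions: `hdim256_bwd_supported` is a hypothetical helper, and the version and compute-capability tuples are assumed to come from whatever lookup the test suite already uses.

```python
def hdim256_bwd_supported(cudnn_version, compute_capability):
    """Return True when the D=256 BWD test should run.

    cudnn_version: (major, minor, patch) tuple, e.g. (9, 23, 0).
    compute_capability: (major, minor) tuple, e.g. (10, 0) for SM100.
    Both inputs are assumptions about how the caller obtains these values.
    """
    new_enough = tuple(cudnn_version) >= (9, 23, 0)
    # SM100 and SM103 are the devices named in the summary above.
    on_blackwell = tuple(compute_capability) in {(10, 0), (10, 3)}
    return new_enough and on_blackwell


# Hypothetical use in a test module:
# @pytest.mark.skipif(
#     not hdim256_bwd_supported(cudnn_version, compute_capability),
#     reason="D=256 BWD requires cuDNN 9.23+ and SM100/SM103",
# )
```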

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["nvte_get_fused_attn_backend()"] --> B{dtype FP16/BF16?}
    B -->|yes| C{flag_arb conditions}
    B -->|no| Z[Other backends / FP8]
    C --> D{arch & version gate}
    D -->|SM80/90 paths| D2[Earlier gates]
    D -->|cuDNN ge 9.7 and SM ge 100| E{head_dim check}
    E --> E1["d le 128 always"]
    E --> E2["d le 256 + Hopper cuDNN ge 9.1/9.5"]
    E --> E3["any d + Blackwell fprop cuDNN ge 9.9"]
    E --> E4["d_qk=192 d_v=128 + Blackwell bprop cuDNN ge 9.11"]
    E --> E5{"NEW: d_qk=d_v=256 + SM10x bprop cuDNN ge 9.23"}
    E5 --> G{no_bias and no_dropout and vanilla_softmax and non-paged and window_size OK?}
    G -->|no| F[Fall through]
    G -->|yes| H[flag_arb = true]
    H --> I{outer mask/format/SWA checks}
    I -->|fail| F
    I -->|pass| J[Return NVTE_F16_arbitrary_seqlen]

Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +382 to +383
# vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only.
# (for non-causal masks) full-window attention.
Contributor

P2 Duplicated comment fragment

The comment block ends with a repeated phrase: line 383 (# (for non-causal masks) full-window attention.) is a verbatim fragment of line 382, left over from editing. It should be removed.

Suggested change
- # vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only.
- # (for non-causal masks) full-window attention.
+ # vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Collaborator

Seems like some editing glitch :)

Collaborator Author

Fixed in 379242b

Comment thread tests/jax/test_fused_attn.py Outdated
Comment on lines +465 to +475
# Non-learnable bias is fine (bias is allowed as an input); only dBias is
# unsupported. The JAX runner asks for dBias iff the bias shape is [1, h, s, s]
# (see test_backward), so gate on that.
unsupported = None
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
    unsupported = "pre-scale bias"
elif self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
    unsupported = (
        "bias gradients (dBias); frozen/non-learnable bias inputs"
        " (i.e. non-1HSS bias shapes) are supported"
    )
Contributor

P2 JAX skip logic diverges from C++ backend gate for non-1HSS bias

The comment says "frozen/non-learnable bias inputs (i.e. non-1HSS bias shapes) are supported" and the skip block deliberately allows those configs to proceed. However, the C++ gate in fused_attn.cpp requires bias_type == NVTE_NO_BIAS for the new D=256 BWD path, meaning any config with attn_bias_type != NO_BIAS && bias_shape != _1HSS will silently fall back to a different backend rather than exercising the new kernel. The test will not fail, but it also will not validate the D=256 BWD path for those configs, and the inline comment creates a misleading expectation that such configs are actually routed through it.
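
To make the divergence concrete, here is a small sketch contrasting the skip logic under review with one that matches the C++ gate. The function names and the string stand-ins for `AttnBiasType` and `BiasShape` are hypothetical, and the aligned version is only one possible way to close the gap, not necessarily what the eventual fix does.

```python
from typing import Optional


def skip_reason_before(bias_type: str, bias_shape: str) -> Optional[str]:
    """Mirrors the reviewed snippet: only pre-scale bias and dBias (1HSS) skip."""
    if bias_type == "pre_scale_bias":
        return "pre-scale bias"
    if bias_type != "no_bias" and bias_shape == "1hss":
        return "bias gradients (dBias)"
    return None


def skip_reason_aligned(bias_type: str) -> Optional[str]:
    """Matches the C++ gate, which requires NVTE_NO_BIAS for the D=256 BWD path."""
    if bias_type != "no_bias":
        return f"D=256 BWD on Blackwell requires no bias (got {bias_type})"
    return None


# A post-scale bias with a non-1HSS shape is not skipped by the reviewed logic,
# yet the C++ gate would not route it through the new kernel.
print(skip_reason_before("post_scale_bias", "b1ss"))
print(skip_reason_aligned("post_scale_bias"))
```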

Collaborator Author

Fixed in b8fe919

attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
(window_size_right == -1 || window_size_right == 0)))) ||
Collaborator

Could these changes be moved to before "\ bias type" just so it's following an increasing order of the cuDNN version?

Collaborator

Does this new feature support BSHD/SBHD and THD? It looks like the tests are focused on BSHD/SBHD only.

Collaborator Author

RE: THD support
I tested BSHD and BSHD+CP, and both passed on the JAX side; the CI on the PyT side did not fail either, so I think that works.
My testing revealed that THD support is not yet available (BWD plan compilation issue), so I've filed a bug and shared a reproducer with the cuDNN team: NVIDIA/cudnn-frontend#276

Collaborator Author

Could these changes be moved to before "\ bias type" just so it's following an increasing order of the cuDNN version?

Fixed in a264de1

"D=256 BWD on Blackwell only supports right window -1 or 0"
" for causal masks."
)

Collaborator

Aren't these checks duplicates of the checks we added on the C++ side? Would calling FusedAttnHelper().get_fused_attn_backend() give you the same gating effect?

Collaborator Author

@KshitijLakhani KshitijLakhani Jun 5, 2026

If we are only interested in the gating effect, you are right: get_fused_attn_backend() will return NVTE_No_Backend, and the catch-all at the end then skips the test because no fused attn backend is available.

However, the reason for keeping these checks is to give a meaningful reason for why a test is being skipped, as opposed to the generic "Unsupported inputs combination or device compute capability." message, which does not explain the skip. Unfortunately, on the JAX attn side we do not log the reason for disabling fused attn in the feature code the way we do on the PyTorch side in d_p_a/utils.py, so the user has no way to know why the test was skipped. Hence, we need to rely on test code to log this on the JAX side.

I'd suggest we leave this in for now. When your PR for generating log messages at the C++ level when selecting the attn backend is ready, I can plumb it through to the JAX side and, as part of that cleanup, get rid of all the skip messages in check_configs().
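
A rough sketch of the pattern being argued for: report a specific reason when one is known, and fall back to the generic message only when the backend query says nothing more. `d256_skip_reason` and the string mask names are hypothetical; the causal-mask message is the one quoted in the snippet above, and the non-causal message is my own wording.

```python
from typing import Optional, Tuple

GENERIC_REASON = "Unsupported inputs combination or device compute capability."


def d256_skip_reason(
    mask_type: str, window: Tuple[int, int], backend_available: bool
) -> Optional[str]:
    """Return a skip reason for a D=256 BWD test, or None if it should run."""
    left, right = window
    is_causal = "causal" in mask_type  # assumed naming convention for mask types
    if (left, right) != (-1, -1):
        if is_causal and right not in (-1, 0):
            return (
                "D=256 BWD on Blackwell only supports right window -1 or 0"
                " for causal masks."
            )
        if not is_causal:
            return (
                "D=256 BWD on Blackwell only supports full-window attention"
                " for non-causal masks."
            )
    if not backend_available:
        # Catch-all: the backend query rejected the config without a known cause.
        return GENERIC_REASON
    return None
```

In a test, the returned string would be passed to `pytest.skip(...)` so the skip report names the actual restriction.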

KshitijLakhani and others added 11 commits June 3, 2026 10:49
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…n fused attn

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/feat/enable-headdim256-bwd-sm100 branch from e317f99 to e08e9e8 Compare June 3, 2026 21:48
Comment on lines +318 to +332
// 9.23: d_qk = d_v = 256 + SM10x (cuDNN FE 1.24 / BE 9.23+) + bprop + non-paged
(head_dim_qk == 256 && head_dim_v == 256 && is_training && sm_arch_ >= 100 &&
 sm_arch_ < 110 && cudnn_runtime_version >= 92300 &&
 layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD &&
 // The FE forces this path onto the deterministic bprop algorithm, which on
 // Blackwell rejects dBias, dropout, and ALiBi (and supports vanilla softmax only).
 bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0 &&
 softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX &&
 // Non-causal D=256 supports only full-window attention; SWA is allowed only for causal masks.
 ((window_size_left == -1 && window_size_right == -1) ||
  ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
    attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
    attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
    attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
   (window_size_right == -1 || window_size_right == 0)))) ||
Contributor

P1 D=256 BWD condition includes THD layout, causing a hard runtime exception

The new condition only excludes NVTE_Paged_KV_HD_HD_HD but does not exclude THD-format layouts. NVTE_THD_T2HD maps to layout_group = NVTE_HD_2HD and qkv_format = NVTE_THD, both of which pass all guards here and in the outer flag_arb checks (the qkv_format check at line 417 allows THD when sm_arch_ >= 90, which is true for SM10x). So nvte_get_fused_attn_backend returns NVTE_F16_arbitrary_seqlen for full-window THD + D=256 + SM10x + cuDNN ≥ 9.23, claiming support — but cuDNN 9.23 fails to build an execution plan for this layout, and NVTE_CHECK_CUDNN_FE on lines 421–422 of fused_attn_f16_arbitrary_seqlen.cu will throw a hard exception. The JAX xfail test documents the failure, but any production user with THD + D=256 training will hit an unrecoverable runtime error rather than a graceful backend fallback. Adding qkv_format != NVTE_QKV_Format::NVTE_THD to this condition would fix the backend selector; the JAX xfail test would then SKIP instead of XFAIL (which could be separately handled if you want to preserve the sentinel behaviour).
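
The effect of the suggested `qkv_format != NVTE_QKV_Format::NVTE_THD` term can be illustrated with a toy selector. This is a deliberately simplified model: `select_backend` is hypothetical, it collapses all other D=256 conditions into one boolean, and it assumes no other branch of the real selector would accept the config.

```python
def select_backend(d256_conditions_met: bool, qkv_format: str, exclude_thd: bool) -> str:
    """Toy model of the selection outcome for a D=256 BWD config on SM10x."""
    gate = d256_conditions_met and not (exclude_thd and qkv_format == "thd")
    # Without the exclusion, THD is reported as supported and the failure
    # surfaces later, when cuDNN tries to build the execution plan.
    return "NVTE_F16_arbitrary_seqlen" if gate else "NVTE_No_Backend"


# As written in the PR: THD is claimed as supported.
print(select_backend(True, "thd", exclude_thd=False))
# With the suggested term: THD is rejected at selection time instead.
print(select_backend(True, "thd", exclude_thd=True))
# BSHD is unaffected either way.
print(select_backend(True, "bshd", exclude_thd=True))
```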

Collaborator Author

This support is forward-looking, i.e., support is added for both THD and BSHD in the TE common fused attn backend checking code; however, the PR is still waiting on cuDNN to fix support for THD.
The current PR will not be merged as is. One of two things will happen:

  1. cuDNN fixes THD support, and only then is this PR merged (most likely), after changing the XFAIL for THD cases to skips for a specific cuDNN version
  2. cuDNN does not fix this soon, in which case I will switch the support to BSHD only prior to merging this PR

KshitijLakhani and others added 4 commits June 4, 2026 17:23
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…ersions

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
2 participants